Eduardo Blancas

Source: scholarpedia.org
Grouping of spikes into clusters based on the similarity of their shapes. Given that each neuron tends to fire spikes of a particular shape, the resulting clusters correspond to the activity of different neurons. The end result of spike sorting is the determination of which spike corresponds to which of these neurons. (Quiroga, 2007)
Most projects involving analysis of neural data use spike sorting as the first step; the analysis continues once the number of neurons and the spike times have been determined.
A crucial consideration going forward is the ability to scale to massive datasets: MEAs (Multi-Electrode Arrays) currently reach the order of 10^4 electrodes, but efforts are underway to increase this number to 10^6 (Lee, 2017).
This project focuses on the clustering step: once spikes have been detected, waveforms are extracted around each spike. Temporal dimensionality is reduced with an autoencoder (from 31 to 3 dimensions). For this project we consider all 7 spatial dimensions (one per channel) in the data, but we are currently working on scaling up clustering by only considering neighboring channels.
I will use specific vocabulary to refer to the data and the results:
Quiroga, R. (2007). Spike Sorting. Scholarpedia. http://www.scholarpedia.org/article/Spike_sorting
Lee, J. et al. (2017). YASS: Yet Another Spike Sorter. Neural Information Processing Systems. Available on bioRxiv: https://www.biorxiv.org/content/early/2017/06/19/151928
This notebook loads the raw data and the pre-processed files generated by the YASS library; we will use one of those files as input for the clustering algorithm.
%matplotlib inline
import logging
import os
import numpy as np
from yass.neuralnet import NeuralNetDetector
from yass.config import Config
from neural_clustering.explore import (SpikeTrainExplorer,
                                       RecordingExplorer)
from neural_clustering import config
import matplotlib.pyplot as plt
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = (10, 10)
logging.basicConfig(level=logging.ERROR)
YASS is a Python package for spike sorting, developed by Peter Lee (a PhD student in the Stats department) and me: https://github.com/paninski-lab/yass
Someone in the lab implemented a truncated DPMM using numpy. Since the code is hard to debug and the only person who understands it is the person who wrote it, I want to see if we can start using Edward instead, so we can iterate quickly and prototype new models easily – without having to write custom inference algorithms every time.
# load configuration files
cfg_yass = Config.from_yaml('../yass_config/demo.yaml')
cfg = config.load('../config.yaml')
# load data generated from yass, we are only interested
# in spike_times and clear_index but we need to load all of them
# to instantiate the explorers that implement the functions we will
# use
files = ['score', 'clear_index', 'spike_times',
         'spike_train', 'spike_left', 'templates']
(score, clear_index,
 spike_times, spike_train,
 spike_left, templates) = [np.load(os.path.join(cfg['root'], 'yass/{}.npy'.format(f)))
                           for f in files]
We load the raw data, the standardized data and the channel geometry (spatial location of the electrodes).
# load raw recordings
path_to_raw_recordings = os.path.join(cfg_yass.root, '7ch.bin')
# load standardized recordings (raw recordings + filtering + standardization)
path_to_recordings = os.path.join(cfg_yass.root, 'tmp/standarized.bin')
# load geometry file (position for every electrode)
path_to_geometry = os.path.join(cfg_yass.root, cfg_yass.geomFile)
# load projection matrix (to reduce dimensionality)
proj = NeuralNetDetector(cfg_yass).load_w_ae()
These helper classes contain functions to plot the data.
# initialize explorers, these objects implement functions for plotting
# the output from YASS
explorer_rec = RecordingExplorer(path_to_recordings,
                                 path_to_geometry,
                                 dtype='float64',
                                 window_size=cfg_yass.spikeSize,
                                 n_channels=cfg_yass.nChan,
                                 neighbor_radius=cfg_yass.spatialRadius)
explorer_raw = RecordingExplorer(path_to_raw_recordings,
                                 path_to_geometry,
                                 dtype='int16',
                                 window_size=cfg_yass.spikeSize,
                                 n_channels=cfg_yass.nChan,
                                 neighbor_radius=cfg_yass.spatialRadius)
explorer_train = SpikeTrainExplorer(templates,
                                    spike_train,
                                    explorer_rec,
                                    proj)
print('Observations: {}. Channels: {}'.format(*explorer_raw.data.shape))
The timeseries plot shows what the raw data looks like in every channel.
plt.rcParams['figure.figsize'] = (60, 60)
explorer_raw.plot_series(from_time=4500, to_time=5000)
The timeseries plot shows what the filtered + standardized data looks like in every channel.
explorer_rec.plot_series(from_time=4500, to_time=5000)
Plot of the electrodes' spatial locations.
plt.rcParams['figure.figsize'] = (10, 10)
explorer_rec.plot_geometry()
To iterate quickly over models, we only train using spikes from channel 0.
clear_indexes = clear_index[0]
clear_spikes = spike_times[0][clear_indexes, 0]
all_spike_times = np.vstack(spike_times)[:, 0]
print('Detected {} clear spikes'.format(clear_spikes.shape[0]))
# there is a bug in yass 0.1.1 that is shifting spike times
clear_spikes = clear_spikes - cfg_yass.BUFF
clear_spikes
Plot for a single detected spike.
plt.rcParams['figure.figsize'] = (10, 15)
t = clear_spikes[0]
explorer_rec.plot_waveform(time=t, channels=range(7))
Once we have the spike times, we load a time window around each one: 15 observations before and 15 after, so we have 31 temporal observations per spike.
waveforms = explorer_rec.read_waveforms(times=clear_spikes)
print('Training set dimensions: {}'.format(waveforms.shape))
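`read_waveforms` handles this slicing for us (reading from the binary recordings file), but the indexing it performs amounts to taking a fixed window around each spike time. A minimal sketch with an in-memory array (the function name and toy data here are illustrative, not the actual implementation):

```python
import numpy as np

def extract_waveforms(recordings, spike_times, half_window=15):
    """Slice a (2 * half_window + 1)-sample window around each spike time.

    recordings: (n_observations, n_channels) array
    spike_times: 1D array of spike sample indices
    """
    window = 2 * half_window + 1
    waveforms = np.empty((len(spike_times), window, recordings.shape[1]))
    for i, t in enumerate(spike_times):
        # window covers [t - half_window, t + half_window], inclusive
        waveforms[i] = recordings[t - half_window:t + half_window + 1]
    return waveforms

# toy example: 1000 observations, 7 channels
recordings = np.random.randn(1000, 7)
spike_times = np.array([100, 250, 800])
waveforms = extract_waveforms(recordings, spike_times)
print(waveforms.shape)  # (3, 31, 7)
```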
We reduce temporal dimensionality (from 31 to 3) in each channel using an autoencoder, so we get 3 * n_channels = 21 features for the training data.
waveforms_reduced = explorer_train._reduce_dimension(waveforms, flatten=True)
print('Training set dimensions: {}'.format(waveforms_reduced.shape))
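The shape change above comes from projecting each channel's 31-sample waveform onto 3 features and flattening across channels. `_reduce_dimension` uses the autoencoder weights loaded via `load_w_ae`; assuming a linear encoder, the operation can be sketched as a matrix projection (shapes and names here are illustrative, and the actual axis ordering in YASS may differ):

```python
import numpy as np

def reduce_dimension(waveforms, proj):
    """Project each channel's waveform onto the encoder features and flatten.

    waveforms: (n_spikes, n_temporal, n_channels)
    proj: (n_temporal, n_features) projection matrix
    """
    # contract the temporal axis -> (n_spikes, n_channels, n_features)
    scores = np.tensordot(waveforms, proj, axes=[[1], [0]])
    # flatten to (n_spikes, n_channels * n_features)
    return scores.reshape(scores.shape[0], -1)

# toy shapes matching this notebook: 31 temporal samples, 7 channels, 3 features
waveforms = np.random.randn(5, 31, 7)
proj = np.random.randn(31, 3)
reduced = reduce_dimension(waveforms, proj)
print(reduced.shape)  # (5, 21)
```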
We save the training data and the clear spike times (we will later use the spike times for visualization).
output_path = os.path.join(cfg['root'], 'training.npy')
np.save(output_path, waveforms_reduced)
print(f'Saved training data in {output_path}')
output_path = os.path.join(cfg['root'], 'clear_spikes.npy')
np.save(output_path, clear_spikes)
print(f'Saved clear spike times in {output_path}')
This notebook shows how we can train models and list previously trained ones (along with useful information about them).
Important: this is just a demo to show how models are trained. Due to the way Edward manages TensorFlow sessions, each model should be trained in a different Python session; otherwise restoring it (notebooks 4-1 and 4-2) will throw errors.
import os
import numpy as np
from neural_clustering.model import dpmm, gmm, util
from neural_clustering.criticize import summarize_experiments
from neural_clustering import config
import logging
logging.basicConfig(level=logging.INFO)
cfg = config.load('../config.yaml')
x_train = np.load(os.path.join(cfg['root'], 'training.npy'))
print(f'x_train shape: {x_train.shape}')
# x_train = util.build_toy_dataset(500)
# print(f'x_train shape: {x_train.shape}')
$$p(x_n | \pi, \mu, \sigma) = \sum_{k=1}^{K} \pi_k \mathrm{Normal}(x_n |\; \mu_k, \sigma_k)$$
$$ \beta_k \sim \mathrm{Beta}(1,\alpha) $$
$$ \pi_k = \beta_k \prod_{j=1}^{k-1}(1-\beta_j) $$
$$\mu_k \sim \mathrm{Normal} (\mu_k |\; \mathbf{0}, \mathbf{I}) $$
$$\sigma_k^2 \sim \mathrm{Gamma}(\sigma^2_k |\; a, b) $$
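The stick-breaking construction above turns independent Beta draws into valid mixture weights: each $\beta_k$ breaks off a fraction of the remaining stick. A small numpy sketch of that computation:

```python
import numpy as np

def stick_breaking(beta):
    """Convert Beta(1, alpha) draws into mixture weights.

    pi_k = beta_k * prod_{j<k} (1 - beta_j)
    """
    # length of the stick remaining before each break: 1, (1-b1), (1-b1)(1-b2), ...
    remaining = np.concatenate([[1.0], np.cumprod(1 - beta)[:-1]])
    return beta * remaining

rng = np.random.default_rng(0)
alpha = 1.0
beta = rng.beta(1, alpha, size=10)
pi = stick_breaking(beta)
# pi sums to at most 1; the sum approaches 1 as the truncation level grows
print(pi)
```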
There is a bug in Edward that sometimes throws an error when fitting; it occurs more often with truncation levels > 5. It has already been reported on the Edward Discourse forum.
# small number of iterations for testing
dpmm.fit(x_train, truncation_level=3, cfg=cfg,
         inference_params=dict(n_iter=500))
$$p(x_n | \pi, \mu, \sigma) = \sum_{k=1}^{K} \pi_k \mathrm{Normal}(x_n |\; \mu_k, \sigma_k)$$
$$\pi \sim \mathrm{Dirichlet}(\pi |\; \alpha \mathbf{1}_K) $$
$$\mu_k \sim \mathrm{Normal} (\mu_k |\; \mathbf{0}, \mathbf{I}) $$
$$\sigma_k^2 \sim \mathrm{Gamma}(\sigma^2_k |\; a, b) $$
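Sampling from this generative process helps make the model concrete. A self-contained numpy sketch (toy dimensions; note `Generator.gamma` is parameterized by shape and scale, so scale is $1/b$ for rate $b$):

```python
import numpy as np

rng = np.random.default_rng(42)
K, D, N = 3, 2, 500          # clusters, feature dimensions, observations
alpha, a, b = 1.0, 1.0, 1.0  # prior hyperparameters

# draw mixture parameters from the priors above
pi = rng.dirichlet(alpha * np.ones(K))          # mixing proportions
mu = rng.normal(0.0, 1.0, size=(K, D))          # cluster means
sigma2 = rng.gamma(a, 1.0 / b, size=(K, D))     # cluster variances

# draw cluster assignments, then observations from the assigned component
z = rng.choice(K, size=N, p=pi)
x = rng.normal(mu[z], np.sqrt(sigma2[z]))
print(x.shape)  # (500, 2)
```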
# small number of iterations for testing
gmm.fit(x_train, k=13, cfg=cfg, samples=10)
List all previously trained models along with some useful information.
summarize_experiments(cfg)
Checking for convergence (the GMM is trained using Gibbs sampling), cluster sizes and Edward's PPC plots.
Important: restart the kernel when changing the experiment to evaluate.
EXPERIMENT_NAME = '30-Nov-2017@16-53-17-GMM'
%matplotlib inline
import logging
import os
import numpy as np
import edward as ed
import tensorflow as tf
import matplotlib.pyplot as plt
from neural_clustering.criticize import (plot, restore,
                                         store_cluster_assignments,
                                         ppc_plot,
                                         summarize_experiment)
from neural_clustering import config
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = (20, 20)
logging.basicConfig(level=logging.INFO)
cfg = config.load('../../config.yaml')
exp = restore.experiment(cfg, EXPERIMENT_NAME)
x_pred = exp['x_pred']
x_train = exp['x_train'].astype('float32')
summarize_experiment(cfg, EXPERIMENT_NAME)
Proportion parameters over every iteration.
plot.params_over_iterations(exp['qpi'], axis=1, sharex=False)
Mixture means over every iteration.
plot.params_over_iterations(exp['qmu'], axis=1, sharex=False)
Find the cluster assignments and save them; they will later be used in notebook 5.
clusters = store_cluster_assignments(cfg, exp['x_train'], exp['qmu'], exp['params'])
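`store_cluster_assignments` is a project helper; one common way to recover hard assignments from fitted means is to assign each point to its nearest posterior mean. A hypothetical sketch of that rule (not the helper's actual implementation):

```python
import numpy as np

def assign_clusters(x, means):
    """Assign each observation to the nearest cluster mean.

    x: (n, d) data; means: (k, d) posterior mean estimates
    """
    # squared Euclidean distance from every point to every mean -> (n, k)
    dists = ((x[:, None, :] - means[None, :, :]) ** 2).sum(axis=2)
    return dists.argmin(axis=1)

x = np.array([[0.0, 0.0], [10.0, 10.0], [9.0, 11.0]])
means = np.array([[0.0, 0.0], [10.0, 10.0]])
print(assign_clusters(x, means))  # [0 1 1]
```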
plt.rcParams['figure.figsize'] = (15, 5)
plot.cluster_counts(clusters)
Evaluate log likelihood and mean squared error.
log_lik = ed.evaluate('log_likelihood', data={x_pred: x_train})
mse = ed.evaluate('mean_squared_error', data={x_pred: x_train})
print(f'Log likelihood is: {log_lik:0.2f}')
print(f'Mean squared error is: {mse:0.2f}')
PPC plots for mean, max and min.
ppc_plot(lambda xs, mus: tf.reduce_mean(xs[x_pred]), 'mean', x_pred, x_train)
ppc_plot(lambda xs, mus: tf.reduce_max(xs[x_pred]), 'max', x_pred, x_train)
ppc_plot(lambda xs, mus: tf.reduce_min(xs[x_pred]), 'min', x_pred, x_train)
In this notebook, we visually inspect the clustering assignments.
EXPERIMENT_NAME = '30-Nov-2017@16-53-17-GMM'
%matplotlib inline
import logging
import os
import numpy as np
from yass.neuralnet import NeuralNetDetector
from yass.config import Config
from neural_clustering.criticize import summarize_experiment
from neural_clustering.explore import (SpikeTrainExplorer,
                                       RecordingExplorer)
from neural_clustering import config
import matplotlib.pyplot as plt
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = (15, 15)
logging.basicConfig(level=logging.ERROR)
# load configuration files
cfg_yass = Config.from_yaml('../../yass_config/demo.yaml')
cfg = config.load('../../config.yaml')
# load data generated from yass
files = ['score', 'clear_index', 'spike_times', 'spike_train', 'spike_left', 'templates']
(score, clear_index,
 spike_times, spike_train,
 spike_left, templates) = [np.load(os.path.join(cfg['root'], 'yass/{}.npy'.format(f)))
                           for f in files]
summarize_experiment(cfg, EXPERIMENT_NAME)
# load standardized recordings (raw recordings + filtering + standardization)
path_to_recordings = os.path.join(cfg_yass.root, 'tmp/standarized.bin')
# load geometry file (position for every electrode)
path_to_geometry = os.path.join(cfg_yass.root, cfg_yass.geomFile)
# load projection matrix (to reduce dimensionality)
proj = NeuralNetDetector(cfg_yass).load_w_ae()
We load the clustering assignments from the experiment.
clusters = np.load(os.path.join(cfg['root'], 'sessions', EXPERIMENT_NAME, 'clusters.npy'))
clear_spikes = np.load(os.path.join(cfg['root'], 'clear_spikes.npy'))
results = np.vstack([clear_spikes, clusters]).T
group_ids = np.unique(clusters)
These helper classes contain several functions for plotting results.
# initialize explorers, these objects implement functions for plotting
# the output from YASS
explorer_rec = RecordingExplorer(path_to_recordings,
                                 path_to_geometry,
                                 dtype='float64',
                                 window_size=cfg_yass.spikeSize,
                                 n_channels=cfg_yass.nChan,
                                 neighbor_radius=cfg_yass.spatialRadius)
explorer_train = SpikeTrainExplorer(templates,
                                    results,
                                    explorer_rec,
                                    proj)
Build templates from the clusters: take every point in each cluster, get the original waveform (31 temporal observations x 7 channels), and average them. Then plot the templates.
def make_template(group_id):
    wfs = explorer_train.waveforms_for_group(group_id=group_id, channels=range(7))
    return np.average(wfs, axis=0)
templates_new = np.stack([make_template(group_id) for group_id in group_ids]).transpose(2, 1, 0)
explorer_train = SpikeTrainExplorer(templates_new,
                                    results,
                                    explorer_rec,
                                    proj)
plt.rcParams['figure.figsize'] = (15, 15)
explorer_train.plot_templates(group_ids=group_ids)
Plot neighboring clusters: get all the templates and compute the similarity among them (by comparing the squared difference at each temporal observation in each channel).
Then, for every template, get the two most similar templates, find the cluster id for each, and get the cluster elements for those cluster ids.
Project all the points in the three clusters using LDA and plot.
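The similarity computation described above can be sketched as a pairwise sum of squared differences over all temporal observations and channels (function names and shapes here are illustrative, not the explorer's actual code):

```python
import numpy as np

def template_distances(templates):
    """Pairwise squared-difference distance between templates.

    templates: (n_templates, n_temporal, n_channels)
    """
    # broadcast to (n_templates, n_templates, n_temporal, n_channels)
    diff = templates[:, None] - templates[None, :]
    # sum squared differences over temporal and channel axes
    return (diff ** 2).sum(axis=(2, 3))

def most_similar(templates, k=2):
    """Indices of the k most similar (lowest-distance) templates for each one."""
    dist = template_distances(templates)
    np.fill_diagonal(dist, np.inf)  # exclude each template's match with itself
    return np.argsort(dist, axis=1)[:, :k]

templates = np.random.randn(5, 31, 7)
print(most_similar(templates).shape)  # (5, 2)
```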
plt.rcParams['figure.figsize'] = (15, 15)
explorer_train.plot_all_clusters(k=3)
Plot 3 similar clusters along with the template for each group.
explorer_train.plot_waveforms_and_clusters(group_id=0)
Checking for convergence (the GMM is trained using Gibbs sampling), cluster sizes and Edward's PPC plots.
Important: restart the kernel when changing the experiment to evaluate.
EXPERIMENT_NAME = '30-Nov-2017@16-58-38-GMM'
%matplotlib inline
import logging
import os
import numpy as np
import edward as ed
import tensorflow as tf
import matplotlib.pyplot as plt
from neural_clustering.criticize import (plot, restore,
                                         store_cluster_assignments,
                                         ppc_plot,
                                         summarize_experiment)
from neural_clustering import config
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = (20, 20)
logging.basicConfig(level=logging.INFO)
cfg = config.load('../../config.yaml')
exp = restore.experiment(cfg, EXPERIMENT_NAME)
x_pred = exp['x_pred']
x_train = exp['x_train'].astype('float32')
summarize_experiment(cfg, EXPERIMENT_NAME)
Proportion parameters over every iteration.
plot.params_over_iterations(exp['qpi'], axis=1, sharex=False)
Mixture means over every iteration.
plot.params_over_iterations(exp['qmu'], axis=1, sharex=False)
Find the cluster assignments and save them; they will later be used in notebook 5.
clusters = store_cluster_assignments(cfg, exp['x_train'], exp['qmu'], exp['params'])
plt.rcParams['figure.figsize'] = (15, 5)
plot.cluster_counts(clusters)
Evaluate log likelihood and mean squared error.
log_lik = ed.evaluate('log_likelihood', data={x_pred: x_train})
mse = ed.evaluate('mean_squared_error', data={x_pred: x_train})
print(f'Log likelihood is: {log_lik:0.2f}')
print(f'Mean squared error is: {mse:0.2f}')
PPC plots for mean, max and min.
ppc_plot(lambda xs, mus: tf.reduce_mean(xs[x_pred]), 'mean', x_pred, x_train)
ppc_plot(lambda xs, mus: tf.reduce_max(xs[x_pred]), 'max', x_pred, x_train)
ppc_plot(lambda xs, mus: tf.reduce_min(xs[x_pred]), 'min', x_pred, x_train)
In this notebook, we visually inspect the clustering assignments.
EXPERIMENT_NAME = '30-Nov-2017@16-58-38-GMM'
%matplotlib inline
import logging
import os
import numpy as np
from yass.neuralnet import NeuralNetDetector
from yass.config import Config
from neural_clustering.criticize import summarize_experiment
from neural_clustering.explore import (SpikeTrainExplorer,
                                       RecordingExplorer)
from neural_clustering import config
import matplotlib.pyplot as plt
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = (15, 15)
logging.basicConfig(level=logging.ERROR)
# load configuration files
cfg_yass = Config.from_yaml('../../yass_config/demo.yaml')
cfg = config.load('../../config.yaml')
# load data generated from yass
files = ['score', 'clear_index', 'spike_times', 'spike_train', 'spike_left', 'templates']
(score, clear_index,
 spike_times, spike_train,
 spike_left, templates) = [np.load(os.path.join(cfg['root'], 'yass/{}.npy'.format(f)))
                           for f in files]
summarize_experiment(cfg, EXPERIMENT_NAME)
# load standardized recordings (raw recordings + filtering + standardization)
path_to_recordings = os.path.join(cfg_yass.root, 'tmp/standarized.bin')
# load geometry file (position for every electrode)
path_to_geometry = os.path.join(cfg_yass.root, cfg_yass.geomFile)
# load projection matrix (to reduce dimensionality)
proj = NeuralNetDetector(cfg_yass).load_w_ae()
We load the clustering assignments from the experiment.
clusters = np.load(os.path.join(cfg['root'], 'sessions', EXPERIMENT_NAME, 'clusters.npy'))
clear_spikes = np.load(os.path.join(cfg['root'], 'clear_spikes.npy'))
results = np.vstack([clear_spikes, clusters]).T
group_ids = np.unique(clusters)
These helper classes contain several functions for plotting results.
# initialize explorers, these objects implement functions for plotting
# the output from YASS
explorer_rec = RecordingExplorer(path_to_recordings,
                                 path_to_geometry,
                                 dtype='float64',
                                 window_size=cfg_yass.spikeSize,
                                 n_channels=cfg_yass.nChan,
                                 neighbor_radius=cfg_yass.spatialRadius)
explorer_train = SpikeTrainExplorer(templates,
                                    results,
                                    explorer_rec,
                                    proj)
Build templates from the clusters: take every point in each cluster, get the original waveform (31 temporal observations x 7 channels), and average them. Then plot the templates.
def make_template(group_id):
    wfs = explorer_train.waveforms_for_group(group_id=group_id, channels=range(7))
    return np.average(wfs, axis=0)
templates_new = np.stack([make_template(group_id) for group_id in group_ids]).transpose(2, 1, 0)
explorer_train = SpikeTrainExplorer(templates_new,
                                    results,
                                    explorer_rec,
                                    proj)
plt.rcParams['figure.figsize'] = (15, 15)
explorer_train.plot_templates(group_ids=group_ids)
Plot neighboring clusters: get all the templates and compute the similarity among them (by comparing the squared difference at each temporal observation in each channel).
Then, for every template, get the two most similar templates, find the cluster id for each, and get the cluster elements for those cluster ids.
Project all the points in the three clusters using LDA and plot.
plt.rcParams['figure.figsize'] = (15, 15)
explorer_train.plot_all_clusters(k=3)
Plot 3 similar clusters along with the template for each group.
explorer_train.plot_waveforms_and_clusters(group_id=0)
Probabilistic programming is great; it helped us iterate quickly over experiments: trying out new models and inference algorithms, and criticizing every experiment.
The research team I am working with has invested a lot of effort in evaluating our current clustering algorithm (MFM). That algorithm is a black box for everyone but the original programmer, since it was implemented from scratch (with no probabilistic programming at all).
Its output is only the clustering assignments, so we have no way of using probabilistic approaches (like PPCs) to evaluate performance without first disentangling the black box.
Although we are not currently using Edward in the YASS package (I hope we start using it soon), the results from this project are helping us improve YASS. If we manage to implement the MFM algorithm in Edward, we may merge it into YASS.
Using Edward has the potential to significantly impact the project's development in many ways: first, we will be able to better criticize our models; second, we will be able to provide YASS users with more models that may work better for their datasets.
I am excited about what I learned in the class; hopefully YASS will be running on Edward soon.